{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Matplotlib created a temporary config/cache directory at /tmp/matplotlib-bhsk56tr because the default path (/home/mikyas/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg \n",
    "import trt_pose.coco\n",
    "import math\n",
    "import os\n",
    "import numpy as np\n",
    "import traitlets\n",
    "import sys\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import trt_pose.models\n",
    "import torch\n",
    "import torch2trt\n",
    "from torch2trt import TRTModule\n",
    "\n",
    "# Input resolution used for the example tensor below (and reused by later cells).\n",
    "WIDTH = 224\n",
    "HEIGHT = 224\n",
    "\n",
    "# Checkpoints: the PyTorch weights and the cached TensorRT-optimized model.\n",
    "# NOTE: the file names say 244 although WIDTH/HEIGHT are 224; the names are kept as they are.\n",
    "MODEL_WEIGHTS = 'model/hand_pose_resnet18_att_244_244.pth'\n",
    "OPTIMIZED_MODEL = 'model/hand_pose_resnet18_att_244_244_trt.pth'\n",
    "\n",
    "# The keypoint/skeleton description determines the topology and the model's output sizes.\n",
    "with open('preprocess/hand_pose.json', 'r') as f:\n",
    "    hand_pose = json.load(f)\n",
    "\n",
    "topology = trt_pose.coco.coco_category_to_topology(hand_pose)\n",
    "num_parts = len(hand_pose['keypoints'])\n",
    "num_links = len(hand_pose['skeleton'])\n",
    "\n",
    "model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()\n",
    "\n",
    "# Example input handed to torch2trt when the optimized model has to be built.\n",
    "data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()\n",
    "\n",
    "# Convert and cache the optimized model only when no cached file exists yet.\n",
    "if not os.path.exists(OPTIMIZED_MODEL):\n",
    "    model.load_state_dict(torch.load(MODEL_WEIGHTS))\n",
    "    model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)\n",
    "    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)\n",
    "\n",
    "# Always load from the cached file so both paths end with the same kind of module.\n",
    "model_trt = TRTModule()\n",
    "model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trt_pose.draw_objects import DrawObjects\n",
    "from trt_pose.parse_objects import ParseObjects\n",
    "\n",
    "# Post-processing helpers built from the hand topology; both thresholds are 0.15.\n",
    "parse_objects = ParseObjects(\n",
    "    topology,\n",
    "    cmap_threshold=0.15,\n",
    "    link_threshold=0.15,\n",
    ")\n",
    "draw_objects = DrawObjects(topology)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_directories_for_classes(no_of_classes, path_dir, dataset_name):\n",
    "    \"\"\"Create one sub-directory per class under ``path_dir/dataset_name``.\n",
    "\n",
    "    The directories are named ``1`` .. ``no_of_classes``. A directory that\n",
    "    already exists is left untouched and reported on stdout.\n",
    "\n",
    "    Args:\n",
    "        no_of_classes: Number of class directories to create.\n",
    "        path_dir: Parent directory that holds the datasets.\n",
    "        dataset_name: Name of the dataset directory inside ``path_dir``.\n",
    "    \"\"\"\n",
    "    dataset_dir = os.path.join(path_dir, dataset_name)\n",
    "    for class_no in range(1, no_of_classes + 1):\n",
    "        dir_to_create = os.path.join(dataset_dir, str(class_no))\n",
    "        try:\n",
    "            os.makedirs(dir_to_create)\n",
    "        except FileExistsError:\n",
    "            # Use plain string formatting: os.path.join() drops every component that\n",
    "            # precedes an absolute path, which used to swallow the message text.\n",
    "            print('The following directory was not created because it already exists: %s' % dir_to_create)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/mikyas/data_collection/hand_dataset_train\n",
      "/home/mikyas/data_collection/hand_dataset_train\n",
      "/home/mikyas/data_collection/hand_dataset_train\n",
      "/home/mikyas/data_collection/hand_dataset_train\n",
      "/home/mikyas/data_collection/hand_dataset_train\n",
      "/home/mikyas/data_collection/hand_dataset_train\n"
     ]
    }
   ],
   "source": [
    "# Machine-specific absolute path: adjust it before running on another setup.\n",
    "dir_datasets = '/home/mikyas/data_collection/' # path where you want to save your collected data\n",
    "dataset_name = \"hand_dataset_train\" # change this to hand_dataset_test when collecting data for test\n",
    "# Number of classes; one directory named 1..no_of_classes is created per class.\n",
    "no_of_classes = 6\n",
    "create_directories_for_classes(no_of_classes, dir_datasets, dataset_name )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import ipywidgets.widgets as widgets\n",
    "# Collection starts at class 1; curr_dir is that class's folder inside the dataset.\n",
    "curr_class_no = 1\n",
    "dir_ = os.path.join(dir_datasets, dataset_name)\n",
    "curr_dir = os.path.join(dir_, str(curr_class_no))\n",
    "\n",
    "# All controls share one fixed-size layout.\n",
    "button_layout = widgets.Layout(width='128px', height='32px')\n",
    "collecting_button = widgets.Button(\n",
    "    description='Collect Class ' + str(curr_class_no),\n",
    "    button_style='success',\n",
    "    layout=button_layout,\n",
    ")\n",
    "prev_button = widgets.Button(\n",
    "    description='Previous Class',\n",
    "    button_style='primary',\n",
    "    layout=button_layout,\n",
    ")\n",
    "nxt_button = widgets.Button(\n",
    "    description='Next Class',\n",
    "    button_style='info',\n",
    "    layout=button_layout,\n",
    ")\n",
    "\n",
    "# Counter showing how many entries the current class folder holds.\n",
    "dir_count = widgets.IntText(value=len(os.listdir(curr_dir)), layout=button_layout)\n",
    "# Bare attribute access: it only echoes the trait's value as this cell's output.\n",
    "dir_count.continuous_update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uuid import uuid1\n",
    "def save_snapshot(directory):\n",
    "    \"\"\"Write the current bytes of ``image_s`` to a new, uniquely named ``.jpg`` file in ``directory``.\"\"\"\n",
    "    file_name = '%s.jpg' % uuid1()\n",
    "    with open(os.path.join(directory, file_name), 'wb') as snapshot_file:\n",
    "        snapshot_file.write(image_s.value)\n",
    "def save_dir():\n",
    "    \"\"\"Save a snapshot into the current class folder and refresh the entry counter.\"\"\"\n",
    "    # No `global` statement needed: the module-level names are only read here.\n",
    "    save_snapshot(curr_dir)\n",
    "    entries_in_class = os.listdir(curr_dir)\n",
    "    dir_count.value = len(entries_in_class)\n",
    "def prev_dir():\n",
    "    \"\"\"Switch collection to the previous class (never below class 1) and refresh the widgets.\"\"\"\n",
    "    global curr_class_no, curr_dir\n",
    "    if curr_class_no > 1:\n",
    "        curr_class_no -= 1\n",
    "    curr_dir = os.path.join(dir_, '%s' % curr_class_no)\n",
    "    collecting_button.description = 'Collect Class ' + str(curr_class_no)\n",
    "    # Show the entry count of the newly selected class folder.\n",
    "    dir_count.value = len(os.listdir(curr_dir))\n",
    "def nxt_dir():\n",
    "    \"\"\"Switch collection to the next class (never beyond ``no_of_classes``) and refresh the widgets.\"\"\"\n",
    "    global curr_class_no, curr_dir\n",
    "    if no_of_classes > curr_class_no:\n",
    "        curr_class_no = curr_class_no + 1\n",
    "    class_label = str(curr_class_no)\n",
    "    curr_dir = os.path.join(dir_, class_label)\n",
    "    collecting_button.description = 'Collect Class ' + class_label\n",
    "    dir_count.value = len(os.listdir(curr_dir))\n",
    "\n",
    "        \n",
    "\n",
    "# The click callbacks are handed one argument by the widget; it is not needed here.\n",
    "collecting_button.on_click(lambda _button: save_dir())\n",
    "nxt_button.on_click(lambda _button: nxt_dir())\n",
    "prev_button.on_click(lambda _button: prev_dir())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torchvision.transforms as transforms\n",
    "import PIL.Image\n",
    "\n",
    "device = torch.device('cuda')\n",
    "# Per-channel mean and standard deviation used to normalize the input, kept on the GPU.\n",
    "mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()\n",
    "std = torch.Tensor([0.229, 0.224, 0.225]).cuda()\n",
    "\n",
    "def preprocess(image):\n",
    "    \"\"\"Convert a BGR frame into a normalized CUDA tensor with a leading batch axis.\n",
    "\n",
    "    Args:\n",
    "        image: Image array accepted by ``cv2.cvtColor`` with ``COLOR_BGR2RGB``.\n",
    "\n",
    "    Returns:\n",
    "        The normalized tensor with one extra leading dimension of size 1.\n",
    "    \"\"\"\n",
    "    global device\n",
    "    device = torch.device('cuda')\n",
    "    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(rgb_frame)).to(device)\n",
    "    # Normalize in place; indexing with [:, None, None] lets the statistics apply per channel.\n",
    "    tensor.sub_(mean[:, None, None])\n",
    "    tensor.div_(std[:, None, None])\n",
    "    return tensor.unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preprocessdata import preprocessdata\n",
    "# NOTE: the instance is bound to the same name as the imported class, shadowing the class from here on.\n",
    "preprocessdata = preprocessdata(topology, num_parts)\n",
    "from gesture_classifier import gesture_classifier\n",
    "# Same pattern: `gesture_classifier` now refers to the instance, not the class.\n",
    "gesture_classifier = gesture_classifier()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define a function that draws the detected joints and skeleton links on the image, which is originally in BGR8 / HWC format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_joints(image, joints):\n",
    "    \"\"\"Draw the joints and the skeleton links onto ``image`` in place.\n",
    "\n",
    "    Nothing is drawn when three or more joints equal ``[0, 0]``.\n",
    "\n",
    "    Args:\n",
    "        image: Image passed to ``cv2.circle`` / ``cv2.line``; it is modified in place.\n",
    "        joints: Sequence of ``[x, y]`` positions. The 1-based joint numbers in\n",
    "            ``hand_pose['skeleton']`` index into it after subtracting one.\n",
    "    \"\"\"\n",
    "    missing = sum(1 for joint in joints if joint == [0, 0])\n",
    "    if missing >= 3:\n",
    "        return\n",
    "    for joint in joints:\n",
    "        cv2.circle(image, (joint[0], joint[1]), 2, (0, 0, 255), 1)\n",
    "    # Redraw the first joint in a different colour.\n",
    "    cv2.circle(image, (joints[0][0], joints[0][1]), 2, (255, 0, 255), 1)\n",
    "    for link in hand_pose['skeleton']:\n",
    "        start = joints[link[0] - 1]\n",
    "        end = joints[link[1] - 1]\n",
    "        if start[0] == 0 or end[0] == 0:\n",
    "            # Skip only this link. A `break` here would also drop every later link\n",
    "            # of the skeleton as soon as a single joint has x == 0.\n",
    "            continue\n",
    "        cv2.line(image, (start[0], start[1]), (end[0], end[1]), (0, 255, 0), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.usb_camera import USBCamera\n",
    "from jetcam.csi_camera import CSICamera\n",
    "from jetcam.utils import bgr8_to_jpeg\n",
    "\n",
    "# USB camera created with the same WIDTH/HEIGHT as the model input; capture_device=1 is the device index handed to jetcam.\n",
    "camera = USBCamera(width=WIDTH, height=HEIGHT, capture_fps=30, capture_device=1)\n",
    "# Alternative for a CSI camera:\n",
    "#camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=30)\n",
    "\n",
    "# Presumably starts continuous capture that keeps camera.value updated - confirm against the jetcam documentation.\n",
    "camera.running = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "# Size the widgets from the shared constants so they cannot drift from the camera resolution.\n",
    "# image_w shows the annotated preview; image_s holds the frame that is written out as a snapshot.\n",
    "image_w = ipywidgets.Image(format='jpeg', width=WIDTH, height=HEIGHT)\n",
    "image_s = ipywidgets.Image(format='jpeg', width=WIDTH, height=HEIGHT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "375b9e6d61384dc1beb1b7cc25a0666b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Image(value=b'', format='jpeg', height='224', width='224')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "380b01ad84b340c1a9342ccea6ad1b46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntText(value=0, layout=Layout(height='32px', width='128px')), Button(button_style='success', d…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d954768e10cf46fb963d60c733baf892",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(Button(button_style='info', description='Next Class', layout=Layout(height='32px', width='128px…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13412c28b4ca46a7b41ab27decca0214",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(Button(button_style='primary', description='Previous Class', layout=Layout(height='32px', width…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Live preview first, then one row per control group.\n",
    "display(\n",
    "    image_w,\n",
    "    widgets.HBox([dir_count, collecting_button]),\n",
    "    widgets.HBox([nxt_button]),\n",
    "    widgets.HBox([prev_button]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(change):\n",
    "    \"\"\"Process one frame: run the pose model, draw the joints and refresh both image widgets.\n",
    "\n",
    "    Args:\n",
    "        change: Mapping whose ``'new'`` entry is the frame to process.\n",
    "    \"\"\"\n",
    "    frame = change['new']\n",
    "    # Store the frame, reversed along axis 1, for snapshots before anything is drawn on it.\n",
    "    image_s.value = bgr8_to_jpeg(frame[:, ::-1, :])\n",
    "    cmap, paf = model_trt(preprocess(frame))\n",
    "    cmap = cmap.detach().cpu()\n",
    "    paf = paf.detach().cpu()\n",
    "    counts, objects, peaks = parse_objects(cmap, paf)\n",
    "    joints = preprocessdata.joints_inference(frame, counts, objects, peaks)\n",
    "    draw_joints(frame, joints)\n",
    "    #draw_objects(frame, counts, objects, peaks)# try this for multiple hand pose prediction \n",
    "    # The preview shows the frame after drawing, reversed the same way.\n",
    "    image_w.value = bgr8_to_jpeg(frame[:, ::-1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline once on the current camera frame.\n",
    "execute({'new': camera.value})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From now on, call execute() whenever camera.value changes.\n",
    "camera.observe(execute, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detach every observer from the camera; run this once data collection is finished.\n",
    "camera.unobserve_all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "#camera.running = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_labels(dir_, dataset_name):\n",
    "    \"\"\"Write one integer label per collected entry and return the list of labels.\n",
    "\n",
    "    Every numerically named sub-directory of ``dir_`` is treated as a class. Its\n",
    "    number is appended once for each entry it contains, in ascending class order.\n",
    "    The result is stored under the key ``labels`` in the file ``dir_ + '.json'``.\n",
    "\n",
    "    Args:\n",
    "        dir_: Dataset directory containing the class sub-directories.\n",
    "        dataset_name: Unused; kept so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        The list of labels that was written to the JSON file.\n",
    "    \"\"\"\n",
    "    # Consider only the numbered class folders: a stray entry (for example\n",
    "    # .ipynb_checkpoints) must not inflate the class count or crash the loop.\n",
    "    class_names = sorted(\n",
    "        (name for name in os.listdir(dir_)\n",
    "         if name.isdecimal() and os.path.isdir(os.path.join(dir_, name))),\n",
    "        key=int,\n",
    "    )\n",
    "    labels = []\n",
    "    for class_name in class_names:\n",
    "        entries = os.listdir(os.path.join(dir_, class_name))\n",
    "        labels.extend([int(class_name)] * len(entries))\n",
    "    with open(dir_ + '.json', 'w') as outfile:\n",
    "        json.dump({'labels': labels}, outfile)\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rename_images(dir_):\n",
    "    \"\"\"Rename every entry of the class folders to a zero-padded running index.\n",
    "\n",
    "    The numbered class folders are visited in ascending order and the counter\n",
    "    keeps running across folders, producing ``00000000.jpg``, ``00000001.jpg``, ...\n",
    "\n",
    "    Args:\n",
    "        dir_: Dataset directory containing the numbered class sub-directories.\n",
    "    \"\"\"\n",
    "    from uuid import uuid4\n",
    "\n",
    "    class_names = sorted(\n",
    "        (name for name in os.listdir(dir_)\n",
    "         if name.isdecimal() and os.path.isdir(os.path.join(dir_, name))),\n",
    "        key=int,\n",
    "    )\n",
    "    overall_count = 0\n",
    "    for class_name in class_names:\n",
    "        class_dir = os.path.join(dir_, class_name)\n",
    "        # Stage 1: move every entry to a unique temporary name. Renaming straight to\n",
    "        # the final name can overwrite a not-yet-processed file that already carries\n",
    "        # that name, e.g. when this function is run a second time.\n",
    "        staged = []\n",
    "        for filename in sorted(os.listdir(class_dir)):\n",
    "            temp_path = os.path.join(class_dir, 'renaming-%s.tmp' % uuid4().hex)\n",
    "            os.rename(os.path.join(class_dir, filename), temp_path)\n",
    "            staged.append(temp_path)\n",
    "        # Stage 2: assign the final sequential names.\n",
    "        for temp_path in staged:\n",
    "            os.rename(temp_path, os.path.join(class_dir, '%08d.jpg' % overall_count))\n",
    "            overall_count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the label file first, then renumber the images; both walk the class folders 1..N in the same order.\n",
    "generate_labels(dir_, dataset_name)\n",
    "rename_images(dir_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/home/mikyas/chitoku/data_collection/hand_dataset_test/hand_dataset_test.json'"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import shutil\n",
    "# Staging folder that collects the files of all classes; it is moved into the dataset directory at the end.\n",
    "dir_training = dir_datasets +'/training/'#change this to /test/ when you are collecting data for test\n",
    "try:\n",
    "    os.makedirs(dir_training)\n",
    "except FileExistsError:\n",
    "    # Use plain string formatting: os.path.join() would discard the text before an absolute path.\n",
    "    print('The following directory was not created because it already exists: %s' % dir_training)\n",
    "# Only the numbered class folders are processed, in ascending order, so stray entries are ignored.\n",
    "class_names = sorted(\n",
    "    (name for name in os.listdir(dir_)\n",
    "     if name.isdecimal() and os.path.isdir(os.path.join(dir_, name))),\n",
    "    key=int,\n",
    ")\n",
    "for class_name in class_names:\n",
    "    dir_to_check = os.path.join(dir_, class_name)\n",
    "    for filename in os.listdir(dir_to_check):\n",
    "        shutil.move(os.path.join(dir_to_check, filename), dir_training)\n",
    "    # The class folder is empty now and can be removed.\n",
    "    os.rmdir(dir_to_check)\n",
    "# Finally place the flattened folder and the label file inside the dataset directory.\n",
    "shutil.move(dir_training,dir_)\n",
    "shutil.move(dir_+'.json',dir_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
